{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 为 MNIST 数据集构建一个分类器，并在测试集上达到超过 90% 的精度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\baohuhe\\appdata\\local\\programs\\python\\python37\\lib\\site-packages\\sklearn\\feature_extraction\\text.py:17: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3,and in 3.9 it will stop working\n",
      "  from collections import Mapping, defaultdict\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.datasets import fetch_mldata\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from sklearn.metrics import precision_score, recall_score\n",
    "from sklearn.model_selection import cross_val_predict\n",
    "from sklearn.model_selection import cross_val_score\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.neighbors import KNeighborsClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "SEED = 42  # fixed seed so the shuffle, CV folds and all scores are reproducible\n",
    "N_TRAIN = 60000  # standard MNIST split: first 60,000 images for training, the rest (10,000) for testing\n",
    "\n",
    "# NOTE(review): fetch_mldata was removed in scikit-learn 0.22; on newer versions use\n",
    "# fetch_openml('mnist_784', version=1, as_frame=False) and cast its string targets to int.\n",
    "mnist = fetch_mldata('MNIST original', data_home='./')\n",
    "X, y = mnist['data'], mnist['target']\n",
    "X_train, X_test, y_train, y_test = X[:N_TRAIN], X[N_TRAIN:], y[:N_TRAIN], y[N_TRAIN:]\n",
    "\n",
    "# Shuffle only the training set, with a seeded generator instead of the global RNG\n",
    "rng = np.random.RandomState(SEED)\n",
    "shuffle_index = rng.permutation(N_TRAIN)\n",
    "X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]\n",
    "\n",
    "# Boolean targets for the 9-vs-rest detector\n",
    "y_train_9 = (y_train == 9)\n",
    "y_test_9 = (y_test == 9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_img2(X, i, interpolation='nearest', image_shape=(28, 28)):\n",
    "    \"\"\"Display row ``i`` of ``X`` as a grayscale digit image.\n",
    "\n",
    "    X: 2-D array whose rows are flattened images (784 pixels for MNIST).\n",
    "    i: index of the row to display.\n",
    "    interpolation: passed to ``plt.imshow``; None means matplotlib's default.\n",
    "    image_shape: (height, width) each row is reshaped into.\n",
    "    \"\"\"\n",
    "    digit_image = X[i].reshape(image_shape)\n",
    "    plt.imshow(digit_image, cmap=matplotlib.cm.binary, interpolation=interpolation)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def show_img1(X, i):\n",
    "    \"\"\"Same as ``show_img2`` but with matplotlib's default interpolation.\"\"\"\n",
    "    show_img2(X, i, interpolation=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAOnklEQVR4nO3db6xU9Z3H8c8XLQ+kPAC5ErTs3i6KWdxEWkfY6IpKswgmBhD/VCJigkGNmlZrIrBRjIl/YmxxH2gjVSxd0QaDriT+wz8khieE0VBESYU1CJQrXDQK1RjEfvfBPTS3eOc3wzkzcwa+71dyMzPnO+ecr+P9cObOb875mbsLwPFvUNkNAGgPwg4EQdiBIAg7EARhB4I4sZ07GzFihHd3d7dzl0Ao27dv1759+2ygWqGwm9lUSf8t6QRJT7r7Q6nnd3d3q1qtFtklgIRKpVKzlvttvJmdIOkxSdMkjZN0jZmNy7s9AK1V5G/2CZK2ufvH7n5Q0h8lTW9OWwCarUjYT5O0s9/jXdmyf2Bm882sambV3t7eArsDUESRsA/0IcD3vnvr7kvdveLula6urgK7A1BEkbDvkjS63+MfSdpdrB0ArVIk7BsknWFmPzazwZJ+Lml1c9oC0Gy5h97c/ZCZ3SrpdfUNvS1z9w+a1hmApio0zu7ur0h6pUm9AGghvi4LBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBBtnbIZ+XzzzTfJ+ptvvlmzdt999yXX3bBhQ66eDrv22muT9XvuuadmbcyYMcl1Bw3iWNRMvJpAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EATj7B3gwIEDyfrVV1+drL/22mu5921mudeVpBUrVuSu9/T0JNcdOXJkrp4wsEJhN7Ptkg5I+k7SIXevNKMpAM3XjCP7xe6+rwnbAdBC/M0OBFE07C5pjZm9a2bzB3qCmc03s6qZVXt7ewvuDkBeRcN+vrv/VNI0SbeY2aQjn+DuS9294u6Vrq6ugrsDkFehsLv77ux2r6QXJU1oRlMAmi932M1siJkNPXxf0hRJm5vVGIDmKvJp/EhJL2bjtCdKetbd8w/4Hse+/PLLZL3eOeFFxtHHjRuXrC9YsCBZr3c+/LZt2466p8NuuOGGZH3KlCnJ+m233ZZ73xHlDru7fyzp7Cb2AqCFGHoDgiDsQBCEHQiCsANBEHYgCE5xbYNnn302WX/55ZcLbf+6666rWbv77ruT6z7wwAPJepGhtXrq/Xe//vrryfqhQ4eS9dtvv/2oezqecWQHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAYZ2+Czz77LFl/7LHHWrr/qVOn1qytXLkyue7TTz+drNe7utDNN9+crJ9++uk1a/PnD3gls7+rN1X1woULk3V3r1m74447kusejziyA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQjLM3wfPPP5+sf/jhh4W2/+ijjybrV111Vc3anXfeWWjfS5YsSdZnz56de9uTJ09O1mfNmpWsr1+/PllftGhRzdo555yTXPfCCy9M1o9FHNmBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjG2RuUOjd6zZo1Ld33zJkzk/VBg2r/m33eeecl1613vnu98egiTj311GR98eLFyfpll12WrB88eLBm7aabbkquW+//6ejRo5P1TlT3yG5my8xsr5lt7rdsuJm9YWZbs9thrW0TQFGNvI3/vaQjL4WyQNJb7n6GpLeyxwA6WN2wu/s7kj4/YvF0Scuz+8slzWhyXwCaLO8HdCPdvUeSsttTaj3RzOabWdXMqr29vTl3B6Coln8a7+5L3b3i7pV6Fy8E0Dp5w77HzEZJUna7t3ktAWiFvGFfLWludn+upJea0w6AVrHU+LEk
mdlzki6SNELSHkmLJf2vpJWS/knSDklXuvuRH+J9T6VS8Wq1WrDlcqTmKR87dmyhbdcby3777beT9aFDhxba/7HqhRdeSNavuOKK3Nuut2697yeUpVKpqFqt2kC1ul+qcfdrapR+VqgrAG3F12WBIAg7EARhB4Ig7EAQhB0IglNcO8CZZ56ZrEcdWqvnkksuSdYnTpxYs1bvMtT79+9P1lOnz0rS4MGDk/UycGQHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAYZ+8Al19+edktHJOGDBmSrF9wwQU1a/XG2etdSnrnzp3J+pgxY5L1MnBkB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgGGdv0DPPPNOybXfimOzxYPbs2TVrjzzySBs76Qwc2YEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMbZG7R79+6yWwAKqXtkN7NlZrbXzDb3W3avmf3FzDZmP5e2tk0ARTXyNv73kqYOsHyJu4/Pfl5pblsAmq1u2N39HUmft6EXAC1U5AO6W81sU/Y2f1itJ5nZfDOrmlm1t7e3wO4AFJE37L+VNEbSeEk9kn5d64nuvtTdK+5e6erqyrk7AEXlCru773H379z9b5J+J2lCc9sC0Gy5wm5mo/o9nClpc63nAugMdcfZzew5SRdJGmFmuyQtlnSRmY2X5JK2S7qxhT12hEqlUrP25JNPFtp2tVpN1s8+++xC2wekBsLu7tcMsPipFvQCoIX4uiwQBGEHgiDsQBCEHQiCsANBcIprgyZPntyyba9duzZZnzdvXsv2fSz74osvkvXrr78+97bPOuusZH348OG5t10WjuxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EATj7A068cTaL9VJJ52UXPfrr79O1r/66qtk/dChQ8l6qrfj2a5du5L1TZs25d72hAnp67EMG1bzSmwdiyM7EARhB4Ig7EAQhB0IgrADQRB2IAjCDgQRc4A2h+7u7pq1adOmJdddtWpVsv7SSy8l6z09Pcn66NGjk/Vj1Y4dO5L1WbNm5d72xRdfnKw//PDDubfdqTiyA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQjLMfAz755JNk/VgdZ1+3bl2yXu96+Vu3bs2977vuuitZP/nkk3Nvu1PVPbKb2WgzW2tmW8zsAzP7RbZ8uJm9YWZbs9tj72x+IJBG3sYfkvQrd/9XSf8u6RYzGydpgaS33P0MSW9ljwF0qLphd/ced38vu39A0hZJp0maLml59rTlkma0qkkAxR3VB3Rm1i3pJ5LWSxrp7j1S3z8Ikk6psc58M6uaWbW3t7dYtwByazjsZvZDSask/dLd9ze6nrsvdfeKu1e6urry9AigCRoKu5n9QH1BX+HuL2SL95jZqKw+StLe1rQIoBnqDr2ZmUl6StIWd/9Nv9JqSXMlPZTdps/TPI7NmTMnWa93ims9V155ZbL+6quv1qyNHz++0L7r2b17d7K+dOnSmrUHH3wwue63336brNe7hPfjjz9eszZx4sTkusejRsbZz5c0R9L7ZrYxW7ZIfSFfaWbzJO2QlP6NBFCqumF393WSrEb5Z81tB0Cr8HVZIAjCDgRB2IEgCDsQBGEHguAU1yaYMmVKsn7jjTcm60888USyvmfPntz7v//++5Pr7tu3L1lftmxZsl5vuulPP/00WU8599xzk/WFCxcm6zNmcLpGfxzZgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIc/e27axSqXi1Wm3b/jrFRx99lKxPmjQpWd+79/i8Lki9cfR657tPnjy5me0cFyqViqrV6oBnqXJkB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgOJ+9DcaOHZusFznnG2gUR3YgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCKJu2M1stJmtNbMtZvaBmf0iW36vmf3FzDZmP5e2vl0AeTXypZpDkn7l7u+Z2VBJ75rZG1ltibs/0rr2ADRLI/Oz90jqye4f
MLMtkk5rdWMAmuuo/mY3s25JP5G0Plt0q5ltMrNlZjasxjrzzaxqZtXe3t5CzQLIr+Gwm9kPJa2S9Et33y/pt5LGSBqvviP/rwdaz92XunvF3StdXV1NaBlAHg2F3cx+oL6gr3D3FyTJ3fe4+3fu/jdJv5M0oXVtAiiqkU/jTdJTkra4+2/6LR/V72kzJW1ufnsAmqWRT+PPlzRH0vtmtjFbtkjSNWY2XpJL2i4pPS8xgFI18mn8OkkDXYf6lea3A6BV+AYdEARhB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCHP39u3MrFfSJ/0WjZC0r20NHJ1O7a1T+5LoLa9m9vbP7j7g9d/aGvbv7dys6u6V0hpI6NTeOrUvid7yaldvvI0HgiDsQBBlh31pyftP6dTeOrUvid7yaktvpf7NDqB9yj6yA2gTwg4EUUrYzWyqmf3ZzLaZ2YIyeqjFzLab2fvZNNTVkntZZmZ7zWxzv2XDzewNM9ua3Q44x15JvXXENN6JacZLfe3Knv687X+zm9kJkj6S9J+SdknaIOkad/+wrY3UYGbbJVXcvfQvYJjZJEl/lfQHd/+3bNnDkj5394eyfyiHuftdHdLbvZL+WvY03tlsRaP6TzMuaYak61Xia5fo6yq14XUr48g+QdI2d//Y3Q9K+qOk6SX00fHc/R1Jnx+xeLqk5dn95er7ZWm7Gr11BHfvcff3svsHJB2eZrzU1y7RV1uUEfbTJO3s93iXOmu+d5e0xszeNbP5ZTczgJHu3iP1/fJIOqXkfo5UdxrvdjpimvGOee3yTH9eVBlhH2gqqU4a/zvf3X8qaZqkW7K3q2hMQ9N4t8sA04x3hLzTnxdVRth3SRrd7/GPJO0uoY8Bufvu7HavpBfVeVNR7zk8g252u7fkfv6uk6bxHmiacXXAa1fm9OdlhH2DpDPM7MdmNljSzyWtLqGP7zGzIdkHJzKzIZKmqPOmol4taW52f66kl0rs5R90yjTetaYZV8mvXenTn7t7238kXaq+T+T/T9J/ldFDjb7+RdKfsp8Pyu5N0nPqe1v3rfreEc2TdLKktyRtzW6Hd1Bv/yPpfUmb1BesUSX19h/q+9Nwk6SN2c+lZb92ib7a8rrxdVkgCL5BBwRB2IEgCDsQBGEHgiDsQBCEHQiCsANB/D+Yh2Mh8J/JwAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Preview sample 11 of the full, unshuffled dataset\n",
    "show_img2(X,11)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0., 6., 1., 9., 7., 3., 5., 6., 4., 7.])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Labels of shuffled training samples 10-19 (the rows later sliced into test_digits)\n",
    "y_train[10:20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAANZUlEQVR4nO3df6hc9ZnH8c/HHxVJqol7r9lg46ZbBFcWNpVBVrJIlrqSqGD6R6VRaxZk0z8UG1BZzQrJn0G3LSJSSVdpXLpK1QajqFuJIVqQ6jVkTbLXVVeybeo1uSFqLf7WZ/+4x3KNd75zM3NmzujzfsFlZs5zzpyH4X7uOXO+M/friBCAL79jmm4AwGAQdiAJwg4kQdiBJAg7kMRxg9zZyMhILF68eJC7BFLZt2+fDh065JlqPYXd9nJJt0k6VtK/RcTG0vqLFy/W2NhYL7sEUNBqtdrWuj6Nt32spDskrZB0lqRVts/q9vkA9Fcv79nPkfRKRLwaER9Iuk/SJfW0BaBuvYT9NEm/m/Z4f7XsM2yvsT1me2xycrKH3QHoRS9hn+kiwOc+exsRmyKiFRGt0dHRHnYHoBe9hH2/pEXTHn9N0mu9tQOgX3oJ+3OSzrD9ddtfkfRdSVvraQtA3boeeouIj2xfI+k/NTX0dndE7K2tMwC16mmcPSIelfRoTb0A6CM+LgskQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1Ioqcpm23vk/S2pI8lfRQRrTqaAlC/nsJe+fuIOFTD8wDoI07jgSR6DXtI+pXt522vmWkF22tsj9kem5yc7HF3ALrVa9iXRsTZklZIutr2eUeuEBGbIqIVEa3R0dEedwegWz2FPSJeq24PStoi6Zw6mgJQv67DbnuO7a9+el/SBZL21NUYgHr1cjV+gaQttj99nv+IiMdr6QpH5a233mpbu/7664vb3n///V0/tyTNmzevWF+7dm3b2nXXXVfcdu7cucU6jk7XYY+IVyX9TY29AOgjht6AJAg7kARhB5Ig7EAShB1Ioo4vwqDPnn322WJ9+fLlbWtvvPFG3e18xptvvlmsb9iwoW3twQcfLG775JNPFusjIyPFOj6LIzuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4+xC4/fbbi/X169cX66Wx9E5j0bfddlux3kmnr6m+/vrrbWu7d+8ubrtq1api/ZFHHinWTzjhhGI9G47sQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+wD0On76L2Mo3dyzz33FOsrVqzo+rkladeuXcX6rbfe2vVz79y5s1gfHx8v1pcsWdL1vr+MOLIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMsw/ATTfdVKz383+7L1iwoG/PLUmXXnppsV76znmncfLDhw8X6zt27CjWGWf/rI5Hdtt32z5oe8+0ZafYfsL2y9Xt/P62CaBXszmN/5mkI6ccuVHStog4Q9K26jGAIdYx7BHxlKQjz6cukbS5ur9Z0sqa+wJQs24v0C2IiAlJqm5Pbbei7TW2x2yPTU5Odrk7AL3q+9X4iNgUEa2IaI2OjvZ7dwDa6DbsB2wvlKTq9mB9LQHoh27DvlXS6ur+akkP1dMOgH7pOM5u+15JyySN2N4vab2kjZJ+YfsqSb+V9J1+NvlFt2jRop62v+OOO4r17du3t631+61Tq9Uq1t97772+7fvDDz/s23N/GXUMe0S0+0/936q5FwB9xMdlgSQIO5AEYQeSIOxAEoQdSIKvuNag0/DSM888U6yffvrpxfoVV1xRrD/88MNta++8805x207DVy+++GKxfsMNNxTrExMTxXovVq9e3Xkl/AlHdiAJwg4k
QdiBJAg7kARhB5Ig7EAShB1IgnH2Grz77rvF+ksvvVSsX3vttcX6SSedVKxv3Lixbe2yyy4rbtvpX00/9thjxXo/nXfeecX6vHnzBtTJlwNHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2GhxzTPlv5pw5c4r1++67r1hfsWJFsb506dK2tU7f+V6/fn2x3qQ777yzWD/++OMH1MmXA0d2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCcfYanHzyycX6BRdcUKxv2bKlWF+5cmWxvnDhwra1AwcOFLeNiGL93HPPLdaPO678K/T000+3rXX6/MHIyEixjqPT8chu+27bB23vmbZsg+3f295V/VzY3zYB9Go2p/E/k7R8huU/jogl1c+j9bYFoG4dwx4RT0k6PIBeAPRRLxforrH9QnWaP7/dSrbX2B6zPTY5OdnD7gD0otuw/0TSNyQtkTQh6YftVoyITRHRiojW6Ohol7sD0Kuuwh4RByLi44j4RNJPJZ1Tb1sA6tZV2G1PH+v5tqQ97dYFMBw6jrPbvlfSMkkjtvdLWi9pme0lkkLSPknf72OPX3h33XVXsX7zzTcX6w888ECxXpoffv78tpdTJEkXXXRRsX722WcX6+vWrSvWS+Psl19+eXFb3vbVq2PYI2LVDIvLv70Ahg4flwWSIOxAEoQdSIKwA0kQdiAJvuI6AJ2GvzrVOw1/Nenxxx/vetszzzyzxk7QCUd2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCcXYUjY+PF+t79+4dUCfoFUd2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCcXYUTUxMFOsffPBB18/NlMyDxZEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnB1FO3bs6NtzL1u2rG/Pjc/reGS3vcj2dtvjtvfa/kG1/BTbT9h+ubotz3QAoFGzOY3/SNJ1EfFXkv5W0tW2z5J0o6RtEXGGpG3VYwBDqmPYI2IiInZW99+WNC7pNEmXSNpcrbZZ0sp+NQmgd0d1gc72YknflPQbSQsiYkKa+oMg6dQ226yxPWZ7bHJysrduAXRt1mG3PVfSg5LWRsQfZrtdRGyKiFZEtEZHR7vpEUANZhV228drKug/j4hfVosP2F5Y1RdKOtifFgHUoePQm21LukvSeET8aFppq6TVkjZWtw/1pUM0auvWrU23gJrMZpx9qaTvSdpte1e1bJ2mQv4L21dJ+q2k7/SnRQB16Bj2iPi1JLcpf6vedgD0Cx+XBZIg7EAShB1IgrADSRB2IAm+4oqi999/v+kWUBOO7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kMRs5mdfJOkeSX8u6RNJmyLiNtsbJP2TpMlq1XUR8Wi/GkUzbrnllmL9yiuvLNbPP//8trUTTzyxq57QndlMEvGRpOsiYqftr0p63vYTVe3HEfGv/WsPQF1mMz/7hKSJ6v7btsclndbvxgDU66jes9teLOmbkn5TLbrG9gu277Y9v802a2yP2R6bnJycaRUAAzDrsNueK+lBSWsj4g+SfiLpG5KWaOrI/8OZtouITRHRiojW6OhoDS0D6Maswm77eE0F/ecR8UtJiogDEfFxRHwi6aeSzulfmwB61THsti3pLknjEfGjacsXTlvt25L21N8egLrM5mr8Uknfk7Tb9q5q2TpJq2wvkRSS9kn6fl86RKMuvvjiYv3w4cMD6gS9ms3V+F9L8gwlxtSBLxA+QQckQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUjCETG4ndmTkv5v2qIRSYcG1sDRGdbehrUvid66VWdvfxERM/7/t4GG/XM7t8ciotVYAwXD2tuw9iXRW7cG1Run8UAShB1Ioumwb2p4/yXD2tuw
9iXRW7cG0luj79kBDE7TR3YAA0LYgSQaCbvt5bb/x/Yrtm9sood2bO+zvdv2LttjDfdyt+2DtvdMW3aK7Sdsv1zdzjjHXkO9bbD9++q122X7woZ6W2R7u+1x23tt/6Ba3uhrV+hrIK/bwN+z2z5W0kuS/kHSfknPSVoVEf890EbasL1PUisiGv8Ahu3zJP1R0j0R8dfVslskHY6IjdUfyvkR8c9D0tsGSX9sehrvaraihdOnGZe0UtI/qsHXrtDXpRrA69bEkf0cSa9ExKsR8YGk+yRd0kAfQy8inpJ05JQrl0jaXN3frKlfloFr09tQiIiJiNhZ3X9b0qfTjDf62hX6Gogmwn6apN9Ne7xfwzXfe0j6le3nba9pupkZLIiICWnql0fSqQ33c6SO03gP0hHTjA/Na9fN9Oe9aiLsM00lNUzjf0sj4mxJKyRdXZ2uYnZmNY33oMwwzfhQ6Hb68141Efb9khZNe/w1Sa810MeMIuK16vagpC0avqmoD3w6g251e7Dhfv5kmKbxnmmacQ3Ba9fk9OdNhP05SWfY/rrtr0j6rqStDfTxObbnVBdOZHuOpAs0fFNRb5W0urq/WtJDDfbyGcMyjXe7acbV8GvX+PTnETHwH0kXauqK/P9K+pcmemjT119K+q/qZ2/TvUm6V1OndR9q6ozoKkl/JmmbpJer21OGqLd/l7Rb0guaCtbChnr7O029NXxB0q7q58KmX7tCXwN53fi4LJAEn6ADkiDsQBKEHUiCsANJEHYgCcIOJEHYgST+H/x+CQXpNQ5uAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Shuffled training sample 13 (the row later sliced into test_digit)\n",
    "show_img2(X_train, 13)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  15, 155, 240,\n",
       "        255, 238, 127,  12,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  54, 237,\n",
       "        254, 254, 254, 254, 254, 210,  20,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 109,\n",
       "        237, 254, 235,  90,  90, 231, 254, 254, 187,   7,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         54, 237, 254, 254,  73,   0,   0, 138, 254, 254, 236,  48,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0, 188, 254, 254, 254,  50,   0,   0,  52, 213, 254, 254,\n",
       "        111,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0, 160, 254, 106, 171,  29,   0,   0,   0, 221,\n",
       "        254, 254, 116,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   8, 217, 225,  14, 183,  39,   5,  45,\n",
       "        184, 246, 254, 254,  21,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,  17, 254, 145,   0,  47, 204,\n",
       "        181, 254, 254, 254, 254, 193,   5,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,  15, 244, 233,  72,\n",
       "         13, 110, 239, 254, 254, 254, 254,  41,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 174,\n",
       "        254, 248, 223, 245, 247, 197, 194, 254, 244,  21,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,  17,  50, 134, 142, 125,  50,   0, 133, 254, 153,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 205, 254, 113,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  41, 249,\n",
       "        254, 113,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         66, 254, 254,  93,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0, 129, 254, 254,  29,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0, 206, 254, 254,  29,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0, 213, 254, 254,  29,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0, 213, 254, 254,  29,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 213, 254, 254,\n",
       "         29,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  76,\n",
       "        237, 170,  20,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "          0,   0,   0,   0]], dtype=uint8)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Slicing [13:14] rather than [13] keeps a 2-D array with a single row, the shape predict() takes\n",
    "test_digit = X_train[13:14]\n",
    "test_digit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ten shuffled training samples for a quick sanity check of predict()\n",
    "test_digits = X_train[10:20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([False, False, False, ..., False, False, False])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Boolean training target: True where the label is 9\n",
    "y_train_9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
       "           metric_params=None, n_jobs=1, n_neighbors=5, p=2,\n",
       "           weights='uniform')"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline 9-vs-rest KNN detector with default hyperparameters;\n",
    "# fit() returns the estimator itself, so it can be chained onto the constructor.\n",
    "knn_clf = KNeighborsClassifier().fit(X_train, y_train_9)\n",
    "knn_clf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([False, False, False,  True, False, False, False, False, False,\n",
       "       False])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predicted 9-vs-rest labels for the ten sample digits\n",
    "knn_clf.predict(test_digits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ True])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predicted 9-vs-rest label for the single sample digit\n",
    "knn_clf.predict(test_digit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3-fold cross-validated accuracy of the 9-detector. KNN prediction over ~20k held-out rows\n",
    "# per fold is slow, so n_jobs=-1 runs the folds in parallel (scores are unchanged).\n",
    "train_score = cross_val_score(knn_clf, X_train, y_train_9, cv=3, scoring='accuracy', n_jobs=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9908"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean cross-validated accuracy over the 3 folds\n",
    "train_score.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed: 21.6min remaining:    0.0s\n",
      "[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed: 42.6min remaining:    0.0s\n",
      "[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed: 64.0min finished\n"
     ]
    }
   ],
   "source": [
    "# Out-of-fold predictions for every training sample, used for the confusion matrix and\n",
    "# precision/recall below. Run serially this took ~64 min; n_jobs=-1 parallelizes the folds.\n",
    "y_train_pred = cross_val_predict(knn_clf, X_train, y_train_9, cv=3, verbose=3, n_jobs=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[53748,   303],\n",
       "       [  249,  5700]], dtype=int64)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Training (out-of-fold) confusion matrix: rows are actual (not 9, 9), columns are predicted\n",
    "confusion_matrix(y_train_9, y_train_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "精度：94.95 %\n",
      "召回率：95.81 %\n"
     ]
    }
   ],
   "source": [
    "# Cross-validated precision and recall of the 9-detector, as percentages\n",
    "train_precision = precision_score(y_train_9, y_train_pred)\n",
    "train_recall = recall_score(y_train_9, y_train_pred)\n",
    "print('精度：{0:.2f} %'.format(100*train_precision))\n",
    "print('召回率：{0:.2f} %'.format(100*train_recall))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-set predictions from the baseline model fit on the full training set\n",
    "y_test_pred = knn_clf.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[8948,   43],\n",
       "       [  47,  962]], dtype=int64)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test confusion matrix of the baseline model: rows are actual (not 9, 9), columns are predicted\n",
    "confusion_matrix(y_test_9, y_test_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "精度：95.72 %\n",
      "召回率：95.34 %\n"
     ]
    }
   ],
   "source": [
    "# Test-set precision and recall of the baseline 9-detector, plus overall accuracy\n",
    "# (the >90% target in the title presumably refers to accuracy, which was never reported).\n",
    "print('精度：{0:.2f} %'.format(100*precision_score(y_test_9, y_test_pred)))\n",
    "print('召回率：{0:.2f} %'.format(100*recall_score(y_test_9, y_test_pred)))\n",
    "print('准确率：{0:.2f} %'.format(100*accuracy_score(y_test_9, y_test_pred)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 3 folds for each of 6 candidates, totalling 18 fits\n",
      "[CV] n_neighbors=2, weights=uniform ..................................\n",
      "[CV] .... n_neighbors=2, weights=uniform, score=0.98995, total=22.3min\n",
      "[CV] n_neighbors=2, weights=uniform ..................................\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed: 64.5min remaining:    0.0s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV] ...... n_neighbors=2, weights=uniform, score=0.991, total=21.1min\n",
      "[CV] n_neighbors=2, weights=uniform ..................................\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed: 126.7min remaining:    0.0s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV] ..... n_neighbors=2, weights=uniform, score=0.9896, total=20.6min\n",
      "[CV] n_neighbors=2, weights=distance .................................\n",
      "[CV] ...... n_neighbors=2, weights=distance, score=0.99, total=20.6min\n",
      "[CV] n_neighbors=2, weights=distance .................................\n",
      "[CV] ..... n_neighbors=2, weights=distance, score=0.991, total=20.6min\n",
      "[CV] n_neighbors=2, weights=distance .................................\n",
      "[CV] .... n_neighbors=2, weights=distance, score=0.9907, total=20.6min\n",
      "[CV] n_neighbors=4, weights=uniform ..................................\n",
      "[CV] .... n_neighbors=4, weights=uniform, score=0.99165, total=20.6min\n",
      "[CV] n_neighbors=4, weights=uniform ..................................\n",
      "[CV] ..... n_neighbors=4, weights=uniform, score=0.9908, total=20.6min\n",
      "[CV] n_neighbors=4, weights=uniform ..................................\n",
      "[CV] ..... n_neighbors=4, weights=uniform, score=0.9905, total=20.6min\n",
      "[CV] n_neighbors=4, weights=distance .................................\n",
      "[CV] ... n_neighbors=4, weights=distance, score=0.99185, total=20.6min\n",
      "[CV] n_neighbors=4, weights=distance .................................\n",
      "[CV] .... n_neighbors=4, weights=distance, score=0.9919, total=20.9min\n",
      "[CV] n_neighbors=4, weights=distance .................................\n"
     ]
    }
   ],
   "source": [
    "# Tune k and the vote weighting. Run serially, each of the 18 fits took ~21 min and the\n",
    "# search never finished, so use all cores with n_jobs=-1. A fresh estimator is passed\n",
    "# inline so the already-fitted baseline knn_clf is not overwritten by an unfitted one.\n",
    "param_grid = [{'weights': [\"uniform\", \"distance\"], 'n_neighbors': [2, 4, 6]}]\n",
    "grid_search = GridSearchCV(KNeighborsClassifier(), param_grid, cv=3, verbose=3, n_jobs=-1)\n",
    "grid_search.fit(X_train, y_train_9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best hyperparameter combination found by the grid search\n",
    "grid_search.best_params_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best estimator, refit on the full training set (GridSearchCV's default refit=True)\n",
    "grid_search.best_estimator_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean cross-validated score (accuracy, the classifier default) of the best combination\n",
    "grid_search.best_score_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction for the single sample digit using the refit best estimator\n",
    "grid_search.predict(test_digit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-set predictions from the tuned model (replaces the baseline y_test_pred)\n",
    "y_test_pred = grid_search.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test confusion matrix of the tuned model: rows are actual (not 9, 9), columns are predicted\n",
    "confusion_matrix(y_test_9, y_test_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-set precision and recall of the tuned 9-detector, plus overall accuracy\n",
    "# (the >90% target in the title presumably refers to accuracy, which was never reported).\n",
    "print('精度：{0:.2f} %'.format(100*precision_score(y_test_9, y_test_pred)))\n",
    "print('召回率：{0:.2f} %'.format(100*recall_score(y_test_9, y_test_pred)))\n",
    "print('准确率：{0:.2f} %'.format(100*accuracy_score(y_test_9, y_test_pred)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*   由于网格搜索耗时过长（18 次拟合串行执行，每次约 21 分钟），代码跑了一下午加一整晚；第二天早上发现电脑已关机，最终结果没有得出。请老师酌情给分，谢谢老师！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
